{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(databricks)=\n",
    "# Databricks runtime\n",
    "\n",
    "The databricks runtime runs on a Databricks cluster (and not in the Iguazio cluster). The function raises a pod on MLRun, which communicates with the Databricks cluster. The requests originate in MLRun and all computing is in the Databricks cluster.\n",
    "\n",
    "With the databricks runtime, you can send your local file/code as a string to the job, and use a handler as an endpoint for user code. You can optionally send keyword arguments (kwargs) to \n",
    "this job. \n",
    "\n",
    "You can run the function on:\n",
    "- An existing cluster, by including `DATABRICKS_CLUSTER_ID`\n",
    "- A job compute cluster, created and dedicated for this function only. \n",
    "\n",
    "Params that are not related to a new cluster or an existing cluster:\n",
    "- `timeout_minutes`\n",
    "- `token_key`\n",
    "- `artifact_json_dir` (location where the json file that contains all logged mlrun artifacts is saved, and which is deleted after the run)\n",
    "\n",
    "Params that are related to a new cluster:\n",
    "- `spark_version`\n",
    "- `node_type_id`\n",
    "- `num_workers`\n",
    "\n",
    "## Example of a job compute cluster\n",
    "\n",
    "To create a job compute cluster, omit `DATABRICKS_CLUSTER_ID`, and set the [cluster specs](https://docs.databricks.com/aws/en/reference/jobs-2.0-api#newcluster) by using the task parameters when running the function. For example:\n",
    "   ```\n",
    "   params['task_parameters'] = {'new_cluster_spec': {'node_type_id': 'm5d.large'}, 'number_of_workers': 2, 'timeout_minutes': 15, `token_key`: non-default-value}\n",
    "   ```\n",
    "Do not send variables named `task_parameters` or `context` since these are utilized by the internal processes of the runtime.\n",
    "\n",
    "## Example of running a Databricks job from a local file\n",
    "\n",
    "This example uses an existing cluster: DATABRICKS_CLUSTER_ID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import mlrun\n",
    "from mlrun.runtimes.function_reference import FunctionReference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If using a Databricks data store, for example, set the credentials:\n",
    "os.environ[\"DATABRICKS_HOST\"] = \"DATABRICKS_HOST\"\n",
    "os.environ[\"DATABRICKS_TOKEN\"] = \"DATABRICKS_TOKEN\"\n",
    "os.environ[\"DATABRICKS_CLUSTER_ID\"] = \"DATABRICKS_CLUSTER_ID\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_databricks_env(function):\n",
    "    job_env = {\n",
    "        \"DATABRICKS_HOST\": os.environ[\"DATABRICKS_HOST\"],\n",
    "        \"DATABRICKS_CLUSTER_ID\": os.environ.get(\"DATABRICKS_CLUSTER_ID\"),\n",
    "    }\n",
    "\n",
    "    for name, val in job_env.items():\n",
    "        function.spec.env.append({\"name\": name, \"value\": val})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "project_name = \"databricks-runtime-project\"\n",
    "project = mlrun.get_or_create_project(project_name, context=\"./\", user_project=False)\n",
    "\n",
    "secrets = {\"DATABRICKS_TOKEN\": os.environ[\"DATABRICKS_TOKEN\"]}\n",
    "\n",
    "project.set_secrets(secrets)\n",
    "\n",
    "code = \"\"\"\n",
    "def print_kwargs(**kwargs):\n",
    "    print(f\"kwargs: {kwargs}\")\n",
    "\"\"\"\n",
    "\n",
    "function_ref = FunctionReference(\n",
    "    kind=\"databricks\",\n",
    "    code=code,\n",
    "    image=\"mlrun/mlrun\",\n",
    "    name=\"databricks-function\",\n",
    ")\n",
    "\n",
    "function = function_ref.to_function()\n",
    "\n",
    "add_databricks_env(function=function)\n",
    "\n",
    "run = function.run(\n",
    "    handler=\"print_kwargs\",\n",
    "    project=project_name,\n",
    "    params={\n",
    "        \"param1\": \"value1\",\n",
    "        \"param2\": \"value2\",\n",
    "        \"task_parameters\": {\"timeout_minutes\": 15},\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logging a Databricks response as an artifact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "\n",
    "def main():\n",
    "    df = pd.DataFrame({\"A\": np.random.randint(1, 100, 5), \"B\": np.random.rand(5)})\n",
    "    path = \"/dbfs/path/folder\"\n",
    "    parquet_df_path = f\"{path}/df.parquet\"\n",
    "    csv_df_path = f\"{path}/df.csv\"\n",
    "\n",
    "    if not os.path.exists(path):\n",
    "        os.makedirs(path)\n",
    "\n",
    "    # save df\n",
    "    df.to_parquet(parquet_df_path)\n",
    "    df.to_csv(csv_df_path, index=False)\n",
    "\n",
    "    # log artifact\n",
    "    mlrun_log_artifact(\"parquet_artifact\", parquet_df_path)\n",
    "    mlrun_log_artifact(\"csv_artifact\", csv_df_path)\n",
    "\n",
    "    # spark\n",
    "    spark = SparkSession.builder.appName(\"example\").getOrCreate()\n",
    "    spark_df = spark.createDataFrame(df)\n",
    "\n",
    "    # spark path format:\n",
    "    spark_parquet_path = \"dbfs:///path/folder/spark_df.parquet\"\n",
    "    spark_df.write.mode(\"overwrite\").parquet(spark_parquet_path)\n",
    "    mlrun_log_artifact(\"spark_artifact\", spark_parquet_path)\n",
    "\n",
    "    # an illegal artifact does not raise an error, it logs an error log instead, for example:\n",
    "    # mlrun_log_artifact(\"illegal_artifact\", \"/not_exists_path/illegal_df.parquet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function = mlrun.code_to_function(\n",
    "    name=\"databricks-log_artifact\",\n",
    "    kind=\"databricks\",\n",
    "    project=project_name,\n",
    "    filename=\"./databricks_job.py\",\n",
    "    image=\"mlrun/mlrun\",\n",
    ")\n",
    "add_databricks_env(function=function)\n",
    "run = function.run(\n",
    "    handler=\"main\",\n",
    "    project=project_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "project.list_artifacts()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
